Imports and definitions¶

In [37]:
import pandas as pd
import torch
import torchvision.transforms.functional as transform
import torchvision.transforms.functional as F
from EnsembleXAI import Ensemble, Metrics
from torchvision.transforms import Resize, CenterCrop
import os
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights
import urllib.request
import json
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from captum.attr import IntegratedGradients, Occlusion, NoiseTunnel, visualization as viz, Saliency
import matplotlib.pyplot as plt
In [2]:
with urllib.request.urlopen("https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json") as url:
    imagenet_classes_dict = json.load(url)
In [3]:
def download_class_images(class_id, masks_path):
    full_path = masks_path + class_id + "\\"
    kaggle_path = f"/ILSVRC/Data/CLS-LOC/train/{class_id}/"
    for file_name in os.listdir(full_path):
        file_name_jpeg = file_name[:-3] + 'JPEG'
        !kaggle competitions download -f {kaggle_path}{file_name_jpeg} -p ./images/{class_id}/ -c imagenet-object-localization-challenge


def download():
    skipped = []
    for class_id in os.listdir(masks_dir):
        image_class_path = os.path.join(images_dir, class_id)
        if os.path.exists(image_class_path) and len(os.listdir(image_class_path)) == 10:
            skipped.append(class_id)
            continue
        #download_class_images(class_id, masks_dir)
        print(f"Downloaded {class_id}")
    print("Full dirs: " + str(len(skipped)))


def images_list(image_path, resize=True):
    _crop = CenterCrop(224).forward
    _resize = Resize([232,232]).forward
    images = []
    for image_name in os.listdir(image_path):
        image = Image.open(image_path + image_name)
        if resize:
            image = _crop(_resize(image))
        images.append(image)
    return images

def dict_to_matrix(original_data, explanations_dict, predictor, masks_tensor):
    df = pd.DataFrame()
    #impact_thresh = 0.4
    #accordance_thresh = 0.2
    for key, value in explanations_dict.items():
        for thresh in range(10):
            df.loc[key, f"Decision Impact Ratio{thresh}"] = Metrics.decision_impact_ratio(original_data, predictor, value, thresh/10, 0)
            df.loc[key, f"Confidence Impact Ratio Same{thresh}"] = Metrics.confidence_impact_ratio(original_data, predictor, value, thresh/10, 0, compare_to="same_prediction")
            df.loc[key, f"CIR Max{thresh}"] = Metrics.confidence_impact_ratio(original_data, predictor, value, thresh/10, 0, compare_to="new_prediction")
            df.loc[key, f"Average Recall{thresh}"] = torch.mean(Metrics.accordance_recall(value, masks_tensor, thresh/10)).item()
            df.loc[key, f"Average Precision{thresh}"] = torch.mean(Metrics.accordance_precision(value, masks_tensor, thresh/10)).item()
        df.loc[key, "F1_score"] = Metrics.F1_score(explanations_dict[key], masks_tensor)
        df.loc[key, "IOU"] = Metrics.intersection_over_union(explanations_dict[key], masks_tensor)
    return df

Images load¶

In [4]:
input_dir = "\\".join(os.getcwd().split(sep="\\")[:-2] + ['input'])
masks_dir = input_dir + f'\\ImageNetS50\\train-semi-segmentation\\'
images_dir = os.getcwd() + "\\images\\"
In [5]:
print(os.listdir(images_dir))
['expl_n01491361.pickle', 'n01443537', 'n01491361', 'n01491361.png', 'n01531178', 'n01644373', 'n02104029', 'n02119022', 'n02123597', 'n02133161', 'n02165456', 'n02281406', 'n02325366', 'n02342885', 'n02396427', 'n02483362', 'n02504458', 'n02510455', 'n02690373', 'n02747177', 'n02783161', 'n02814533', 'n02859443', 'n02917067', 'n02992529', 'n03014705', 'n03047690', 'n03095699', 'n03197337', 'n03201208', 'n03445777', 'n03452741', 'n03584829', 'n03630383', 'n03775546', 'n03791053', 'n03874599', 'n03891251', 'n04026417', 'n04335435', 'n04380533', 'n04404412', 'n04447861', 'n04507155', 'n04522168', 'n04557648', 'n04562935', 'n04612504', 'n06794110', 'n07749582', 'n07831146', 'n12998815']
In [6]:
id = "n01491361"
def load_all(classid):
    all_img = images_list(images_dir + classid + "\\")
    all_img_org = images_list(images_dir + classid + "\\", resize=False)
    all_tens = [F.to_tensor(img) for img in all_img]
    all_msks = [(F.to_tensor(img)>0).float() for img in images_list(masks_dir + classid + "\\")]
    tens_img = torch.stack(all_tens)
    tens_msks = torch.stack(all_msks)[:,0].unsqueeze(dim=1).repeat(1, tens_img.shape[1], 1, 1)
    return all_img, all_img_org, all_tens, all_msks, tens_img, tens_msks
all_images, all_images_original, all_tensors, all_masks, tensor_images, tensor_masks = load_all(id)
In [7]:
photos = []
for tensor, mask in zip(all_tensors, all_masks):
    photo = torch.cat([tensor, mask], dim=2)
    photos.append(photo)
display(transform.to_pil_image(torch.cat(photos, dim=1)))

Model Loading¶

In [8]:
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval()
resnet_transform = ResNet50_Weights.DEFAULT.transforms()
pipeline = lambda images: torch.stack([resnet_transform(image) for image in images])
proper_data = pipeline(all_images_original)
In [9]:
outputs2 = model(proper_data)
_, preds2 = torch.max(outputs2, 1)
probs2 = torch.nn.functional.softmax(outputs2, dim=1)
[imagenet_classes_dict[str(i.item())][1] for i in preds2] # gar = Niszczukokształtne
Out[9]:
['tiger_shark',
 'tiger_shark',
 'tiger_shark',
 'great_white_shark',
 'tiger_shark',
 'tiger_shark',
 'tiger_shark',
 'tiger_shark',
 'hammerhead',
 'gar']

Single Explanations¶

In [10]:
single_pred = preds2[2].unsqueeze(dim=0)
single_data = proper_data[2].unsqueeze(dim=0)
integrated_gradients = IntegratedGradients(model)
attributions_ig = integrated_gradients.attribute(single_data, target=single_pred, n_steps=200)

Basing on: https://captum.ai/tutorials/Resnet_TorchVision_Interpret

In [11]:
transformed_img = resnet_transform(all_images_original[2])
default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                 [(0, '#ffffff'),
                                                  (0.25, '#000000'),
                                                  (1, '#000000')], N=256)

_ = viz.visualize_image_attr(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
                             np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                             method='heat_map',
                             cmap=default_cmap,
                             show_colorbar=True,
                             sign='positive',
                             outlier_perc=1)
display(all_images[2])
In [12]:
import gc
gc.collect()
Out[12]:
190
In [13]:
noise_tunnel = NoiseTunnel(integrated_gradients)

attributions_ig_nt = noise_tunnel.attribute(single_data, nt_samples=5, nt_type='smoothgrad_sq', target=single_pred)
In [14]:
attributions_ig_nt_all = torch.cat([noise_tunnel.attribute(tensor_images[i].unsqueeze(dim=0), nt_samples=5, nt_type='smoothgrad_sq', target=preds2[i].unsqueeze(dim=0)) for i in range(10)], dim=0)
In [15]:
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig_nt.squeeze().numpy(), (1,2,0)),
                                      np.array(all_images[2]),
                                      ["heat_map", "original_image"],
                                      ["positive", "all"],
                                      cmap=default_cmap,
                                      show_colorbar=True)
In [16]:
occlusion = Occlusion(model)

attributions_occ = occlusion.attribute(single_data,
                                       strides = (3, 8, 8),
                                       target=single_pred,
                                       sliding_window_shapes=(3, 15, 15),
                                       baselines=0)
In [17]:
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.array(all_images[2]),
                                      ["heat_map", "original_image"],
                                      ["positive", "all"],
                                      show_colorbar=True,
                                      outlier_perc=2,
                                      )
In [18]:
occlusion = Occlusion(model)

attributions_occ2 = occlusion.attribute(single_data,
                                       strides = (3, 20, 20),
                                       target=single_pred,
                                       sliding_window_shapes=(3, 25, 25),
                                       baselines=0)
_2 = viz.visualize_image_attr_multiple(np.transpose(attributions_occ2.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.array(all_images[2]),
                                      ["heat_map", "original_image"],
                                      ["all", "positive"],
                                      show_colorbar=True,
                                      outlier_perc=2,
                                      )
In [19]:
occlusion = Occlusion(model)

attributions_occ_all_25 = occlusion.attribute(tensor_images,
                                        strides = (3, 20, 20),
                                        target = preds2,
                                        sliding_window_shapes = (3, 25, 25),
                                        baselines=0)

attributions_occ_all_15 = occlusion.attribute(tensor_images,
                                           strides = (3, 8, 8),
                                           target = preds2,
                                           sliding_window_shapes = (3, 15, 15),
                                           baselines=0)
In [92]:
saliency = Saliency(model)
attr_saliency = saliency.attribute(tensor_images, target=preds2)
In [107]:
_3 = viz.visualize_image_attr_multiple(np.transpose(attr_saliency[2].numpy(), (1,2,0)),
                                       np.array(all_images[2]),
                                       ["heat_map", "original_image"],
                                       ["positive", "positive"],
                                       show_colorbar=True,
                                       outlier_perc=2,
                                       )
In [20]:
def sample_xai(images):
    if images.shape[0] == 1:
        target = single_pred
    else:
        target = single_pred.repeat(images.shape[0])
    xai = occlusion.attribute(images,
                              strides = (3, 40, 40),
                              target=target,
                              sliding_window_shapes=(3,50, 50),
                              baselines=0)
    return xai
Metrics.stability(sample_xai, single_data.squeeze(dim=0), single_data.repeat(10,1,1,1))
Out[20]:
0.0

Ensembles¶

In [118]:
x = torch.cat([attributions_occ, attributions_ig_nt])
aggregated1 = Ensemble.basic(x, aggregating_func='avg')
aggregated2 = Ensemble.basic(x, aggregating_func='min')
aggregated3 = Ensemble.basic(x, aggregating_func='max')
In [23]:
#display(transform.to_pil_image(aggregated1[0]))
_ = viz.visualize_image_attr_multiple(np.transpose(aggregated1.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.array(all_images[2]),
                                      ["heat_map", "original_image", "masked_image"],
                                      ["all", "positive", "positive"],
                                      show_colorbar=True,
                                      outlier_perc=2,
                                      )
In [24]:
all_stacked = torch.stack([attributions_occ_all_15, attributions_occ_all_25], dim=1)
In [119]:
attr_agg_avg = Ensemble.basic(all_stacked, aggregating_func='avg')
attr_agg_min = Ensemble.basic(all_stacked, aggregating_func='min')
attr_agg_max = Ensemble.basic(all_stacked, aggregating_func='max')
In [25]:
plt.style.use('fast')
In [161]:
def plot_explanations(images, explanations_dict, classes_predicted,
                      #cmaps=[default_cmap, None, None, None, None, None, None],
                      method = "heat_map"):
    nrow, ncol = len(images), len(explanations_dict.keys())+1
    fig, ax = plt.subplots(nrows = nrow, ncols = ncol, figsize=(14, 3*nrow))
    columns_names = ["Original"] + list(explanations_dict.keys())
    for col, col_name in zip(ax[0], columns_names):
        col.title.set_text(col_name)
    for i, img in enumerate(images):
        ax[i,0].xaxis.set_ticks_position("none")
        ax[i,0].yaxis.set_ticks_position("none")
        ax[i,0].set_yticklabels([])
        ax[i,0].set_xticklabels([])
        ax[i,0].imshow(np.array(images[i]), vmin=0, vmax=255)
        ax[i,0].set_ylabel(classes_predicted[i], size='large')
        for j, (col, (key, explanations)) in enumerate(zip(ax[i,1:], explanations_dict.items())):
            #ith image, jth explanation
            #expl = explanations[i,j]
            expl = explanations[j]
            sign = "all"
            cmap=None
            if expl.amin() >= 0:
                sign = "positive"
                cmap = default_cmap
            _ = viz.visualize_image_attr(np.transpose(expl.squeeze().numpy(), (1,2,0)),
                                         original_image=np.array(img),
                                         method=method,
                                         sign=sign,
                                         plt_fig_axis=(fig, col),
                                         show_colorbar=True,
                                         outlier_perc=2,
                                         cmap=cmap,
                                         use_pyplot=False
                                         )
    plt.savefig(f"images/{id}.png")
    plt.show()
In [162]:
expl_dict = {"Gradients":attributions_ig_nt_all, "Saliency":attr_saliency, "Occlusion 25":attributions_occ_all_25,
             "Occlusion 15":attributions_occ_all_15, "Max Aggregate":attr_agg_max,
             "Min Aggregate":attr_agg_min, "Avg Aggregate":attr_agg_avg}
explanations_three = torch.cat([all_stacked, attr_saliency.unsqueeze(dim=1), attributions_occ_all_15.unsqueeze(dim=1),attr_agg_max.unsqueeze(dim=1), attr_agg_min.unsqueeze(dim=1), attr_agg_avg.unsqueeze(dim=1)], dim=1)
predicted_names = [imagenet_classes_dict[str(i.item())][1] for i in preds2]
plot_explanations(all_images, expl_dict, predicted_names, method="blended_heat_map")
In [121]:
predict = lambda x: torch.nn.Softmax(dim=0)(model(x))
dict_to_matrix(proper_data, expl_dict, predict, tensor_masks)
Out[121]:
Decision Impact Ratio0 Confidence Impact Ratio Same0 CIR Max0 Average Recall0 Average Precision0 Decision Impact Ratio1 Confidence Impact Ratio Same1 CIR Max1 Average Recall1 Average Precision1 ... CIR Max8 Average Recall8 Average Precision8 Decision Impact Ratio9 Confidence Impact Ratio Same9 CIR Max9 Average Recall9 Average Precision9 F1_score IOU
Gradients 1.0 0.292043 0.292043 1.000000 0.239076 0.0 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 0.000000 0.358053 0.000000
Saliency 1.0 0.292043 0.292043 1.000000 0.239318 0.8 0.130759 -0.075646 0.200246 0.477015 ... -0.009288 0.000763 0.428025 0.2 0.016281 -0.000817 0.000368 0.318214 0.358287 0.004982
Occlusion 25 0.8 0.126705 -0.086325 0.567629 0.260502 0.8 0.084125 -0.002501 0.266416 0.432520 ... -0.002693 0.000890 0.100000 0.0 0.000000 0.000000 0.000000 0.000000 0.302462 0.023134
Occlusion 15 1.0 0.243223 -0.117268 0.482177 0.238699 0.9 0.135679 -0.027287 0.143442 0.459755 ... 0.000890 0.000057 0.100000 0.0 0.000000 0.000000 0.000000 0.000000 0.277534 0.004646
Max Aggregate 1.0 0.237470 -0.106383 0.668883 0.246700 0.8 0.157906 -0.039502 0.305328 0.435404 ... -0.001771 0.000947 0.200000 0.0 0.000000 0.000000 0.000000 0.000000 0.314795 0.024842
Min Aggregate 0.9 0.168878 -0.132847 0.380922 0.254946 0.5 0.097367 0.023040 0.104529 0.442264 ... 0.000000 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 0.000000 0.256129 0.002419
Avg Aggregate 1.0 0.185669 -0.134782 0.536309 0.250199 0.7 0.092850 -0.034366 0.192898 0.436170 ... 0.000000 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 0.000000 0.291462 0.006486

7 rows × 52 columns

In [183]:
per_image = pd.DataFrame()
per_image["Average Recall"] = torch.mean(torch.stack([Metrics.accordance_recall(expl, tensor_masks, 0.2) for expl in expl_dict.values()]), dim=0).numpy()
per_image["Average Precision"] = torch.mean(torch.stack([Metrics.accordance_precision(expl, tensor_masks, 0.2) for expl in expl_dict.values()]), dim=0).numpy()
per_image["Consistency"] = [Metrics.consistency(expls) for expls in torch.stack(list(expl_dict.values()), dim=1)]
per_image
Out[183]:
Average Recall Average Precision Consistency
0 0.107373 0.373676 0.015499
1 0.029311 0.664992 0.025282
2 0.004816 0.371717 0.026888
3 0.335163 0.201642 0.018533
4 0.052924 0.847267 0.027571
5 0.004855 0.056777 0.023819
6 0.012087 0.446683 0.047301
7 0.000784 0.045918 0.033685
8 0.047417 0.172541 0.022880
9 0.189759 0.348346 0.020937
In [184]:
import pickle
a=expl_dict
with open(f'images/expl_{id}.pickle', 'wb') as handle:
    pickle.dump(a, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open(f'images/expl_{id}.pickle', 'rb') as handle:
    b = pickle.load(handle)
b.keys()
Out[184]:
dict_keys(['Gradients', 'Saliency', 'Occlusion 25', 'Occlusion 15', 'Max Aggregate', 'Min Aggregate', 'Avg Aggregate'])
In [ ]: